from mynumpy import *
from util import *
from estimator_iid import estimator_iid

# Earlier variant (commented out below): stratify only the FIRST dimension of the base estimator's omega.

# def stratify1(omega,bin,bins):
#     omega2 = omega + 0.0
#     omega2[0,:] = omega2[0,:]/bins + bin/bins
#     return omega2

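# Active variant: stratify ALL dimensions of an omega block together -- every entry x is mapped to x/bins + bin/bins, so entries in [0, 1) land in [bin/bins, (bin+1)/bins).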

def stratify1(omega, bin, bins):
    """Map a block of samples into the `bin`-th of `bins` equal strata.

    Every entry x of `omega` is sent to x / bins + bin / bins, so entries in
    [0, 1) land in [bin / bins, (bin + 1) / bins).  All dimensions of the
    block are shifted together.

    Parameters
    ----------
    omega : numpy array or scalar
        Samples to shift (presumably uniform on [0, 1)).  Not modified.
    bin : int
        Index of the target stratum, expected 0 <= bin < bins.
    bins : int
        Total number of strata.

    Returns
    -------
    A new float array (or float) with the same shape as `omega`.
    """
    # `+ 0.0` promotes integer input to float so the division is never a
    # floor division; float(bin) does the same for the offset, which would
    # otherwise truncate to 0 under Python 2 integer division.
    return (omega + 0.0) / bins + float(bin) / bins

class estimator_stratified:
    """Stratified wrapper around a base estimator.

    Omega vectors for this estimator hold `bins` stacked row-blocks
    (``omega_dim == base_est.omega_dim * bins``).  Before any evaluation the
    k-th row-block is shifted into the k-th of `bins` equal strata via
    ``stratify1``.  The stacked result is then forwarded to an
    ``estimator_iid`` constructed from the same base estimator and `bins`.
    """

    def __init__(self, base_est, bins):
        """Wrap `base_est` using `bins` strata.

        Sets: base_est, bins, label ('str<bins>-<base label>'), omega_dim,
        est_iid (the delegate that does the real work -- presumably it treats
        the `bins` row-blocks as IID replicates; confirm in estimator_iid),
        and w_dim (copied from the base estimator).
        """
        self.base_est = base_est
        self.bins = bins
        self.label = ''.join(['str', str(bins), '-', self.base_est.label])
        self.omega_dim = base_est.omega_dim * bins
        # Reuse the IID estimator instead of duplicating its logic here.
        self.est_iid = estimator_iid(base_est, bins)
        self.w_dim = base_est.w_dim

    def sample_omega(self, num):
        """Return `num` raw omega columns as ``rand(self.omega_dim, num)``.

        The samples are NOT stratified yet; that happens inside logR,
        sample_z and sample_zs.  (Assumes `rand` is numpy's uniform [0, 1)
        sampler re-exported by mynumpy -- confirm.)
        """
        return rand(self.omega_dim, num)

    def stratify(self, omegas):
        """Shift each of the `bins` row-blocks of `omegas` into its own stratum.

        `omegas` is cut along axis 0 into `bins` equal row-blocks (np.split
        raises ValueError when the row count is not divisible by `bins`).
        Block k is mapped by ``stratify1(block, k, self.bins)``, and the
        shifted blocks are stacked back together in their original order.
        """
        row_blocks = np.split(omegas, self.bins)
        shifted_blocks = []
        for k, block in enumerate(row_blocks):
            shifted_blocks.append(stratify1(block, k, self.bins))
        return np.vstack(shifted_blocks)

    def logR(self, omegas, w):
        """Stratify `omegas`, then defer to the IID estimator's logR."""
        return self.est_iid.logR(self.stratify(omegas), w)

    def sample_z(self, omegas, w):
        """Stratify `omegas`, then defer to the IID estimator's sample_z."""
        return self.est_iid.sample_z(self.stratify(omegas), w)

    def sample_zs(self, omegas, w):
        """Stratify `omegas`, then defer to the IID estimator's sample_zs."""
        return self.est_iid.sample_zs(self.stratify(omegas), w)
